course: (ML Series) Decision Trees & Random Forests

By Julien Hernandez Lallement, 2021-02-12, in category Course

machine learning, python

Decision Trees

General Background

Decision Trees (DTs) are supervised learning techniques. In a nutshell, they are flowcharts that can be used for both regression and classification. I personally like decision trees because they are easy to explain and quite intuitive (you can visualize them easily). However, they tend to be sensitive to changes in the training dataset and can be slow depending on the hyperparameters used.

Use case

DTs can be used to:

  • Classify data based on features (categorical or continuous)
  • Predict a numerical value based on a series of independent variables

Theoretical Background

Decision trees are quite easy to understand. Let's visualize one, from this blog.

Understanding a tree is quite simple :)

In [5]:
from IPython.display import Image
PATH = "/home/julien/website/content/images/2020_12_DM/"
Image(filename = PATH + "decision-tree.png", width=500, height=500)
Out[5]:

If you look at the tree above, you simply have to follow the different paths to get to a node. At the node, you then take the most common class (or the average value for regression). And so on. I found that this is maybe the biggest benefit of DTs: they are so easy to explain and understand. This is a non-negligible advantage when it comes to explaining your model to non-tech stakeholders, who might be somewhat resistant to new machine-driven decisions.

Building a tree

As you can see, the tree is built upside down. Each condition is represented by a question, or node, which splits the features. From each condition, branches emerge, representing the possible answers. When you get to the end of a branch where no more conditions are present, you have reached a leaf.
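Just to make this vocabulary concrete, here is a minimal sketch (using scikit-learn's bundled iris dataset purely as an illustration) that fits a shallow tree and prints its flowchart as text:

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

# Fit a shallow tree so the printed flowchart stays readable
iris = load_iris()
tree = DecisionTreeClassifier(max_depth=2, random_state=0)
tree.fit(iris.data, iris.target)

# Indented questions are nodes, the branches are the possible answers,
# and lines ending in "class: ..." are leaves
print(export_text(tree, feature_names=list(iris.feature_names)))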

As mentioned before, while DTs are typically referred to for classification problems (whether or not to eat a French cheese based on its attributes), you can also use DTs for regression approaches.

One important concept in DTs is purity. Think about it: what you aim for are pure leaves. In other words, you would want all the training examples landing in a given leaf to belong to the same class. During model testing, landing on such a leaf gives you the highest level of confidence that you are in the right spot.

Defining impurity

1. Gini

While there are other measures (see below), impurity is typically defined as the Gini impurity, given by the following formula:

$$Gini_{i} = 1 - \sum_{k=1}^{n}{p_{i,k}^2}$$

where $i$ is the node of interest, $n$ is the total number of categories in your dataset, and $p_{i,k}$ is the fraction of class $k$ in node $i$. The lower the impurity, the more homogeneous your nodes are. As a result, a Gini of 0 for a given node means that all the examples in it belong to the same category. Conversely, impurity is highest when there is an equal probability of being in each category.

Imagine you are trying to predict the gender of people, i.e., whether someone is male or female. You could use height as a feature, and force a separation at some height value. In this node, I may find 90 examples, of which 65 are female and 25 are male. Gini is then:

$$1 - (65/90)^2 - (25/90)^2 \approx 0.40$$

In other terms, this value tells you how good a node is. We first look for the most promising node, and start growing the tree top-down from there. This means that not all trees are considered, only the ones starting with the purest node.

Therefore, the algorithm looks at all the features and all the splitting points for our features. It takes the weighted average of the Gini impurity of the nodes created by each split. The split that creates the lowest weighted average impurity is selected. The algorithm then moves on to the next split.
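As a quick sanity check of the numbers above, here is a minimal sketch in plain NumPy (not the scikit-learn internals) that computes the Gini impurity of a node, and the weighted average impurity of a hypothetical split into two child nodes:

import numpy as np

def gini(counts):
    # counts: number of training examples of each class in a node
    p = np.asarray(counts) / np.sum(counts)
    return 1 - np.sum(p ** 2)

# The node from the example above: 65 female, 25 male
print(round(gini([65, 25]), 2))   # ~0.40

# Weighted average impurity of a split into two hypothetical child nodes
left, right = [60, 5], [5, 20]
n_left, n_right = sum(left), sum(right)
weighted = (n_left * gini(left) + n_right * gini(right)) / (n_left + n_right)
print(round(weighted, 2))         # ~0.19, lower than the parent node's ~0.40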

2. Entropy

As mentioned above, there are other measures of impurity that can be used with decision trees. Probably the second most popular one is entropy. To put it simply, entropy is a measure of disorder. In other words, a high level of entropy is equivalent to a low level of purity. Entropy typically ranges between [0;1], although depending on your dataset (more than two classes), you might end up with values > 1. In any case, the goal here is to decrease the value of entropy in your model.

Entropy can be written in the following mathematical terms:

$$E(S) = -\sum_{i=1}^{c}{p_i \log_2 p_i}$$

where $p_i$ is the probability of occurrence of an element/class $i$ in our data, and $c$ is the number of classes.

Let's assume we have two classes in our dataset, which contains 1000 data points. 350 belong to the first class and the remaining 650 belong to the second class.

This gives us:

$$ - \frac{35}{100} \log_2 \frac{35}{100} - \frac{65}{100} \log_2 \frac{65}{100} \approx 0.93$$

This would be considered a high level of entropy, which your tuning could aim to decrease.

Usually, which measure of impurity you choose doesn't matter too much. Gini is slightly faster to compute, while entropy produces slightly more balanced trees.
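The same kind of sanity check for entropy (a minimal sketch, using the two-class example above):

import numpy as np

def entropy(counts):
    # counts: number of training examples of each class in a node
    p = np.asarray(counts) / np.sum(counts)
    return -np.sum(p * np.log2(p))

print(round(entropy([350, 650]), 2))   # ~0.93, the example above
print(entropy([500, 500]))             # 1.0: maximum disorder for two classes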

3. Regression

Decision trees can also be used in regression problems. Here, instead of using Gini impurity or entropy, you would use the mean squared error (MSE). The MSE is calculated at the node level, with the node's predicted value being the average of the target values of the points in that node.
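As a minimal sketch of that idea (with toy target values, not a real dataset): the node's prediction is the mean of the targets that landed in it, and the MSE around that mean is what the splits try to reduce.

import numpy as np

# Target values of the training points that ended up in one node (toy numbers)
y_node = np.array([3.0, 4.5, 5.0, 6.5])

node_prediction = y_node.mean()                      # value predicted by this node/leaf
node_mse = np.mean((y_node - node_prediction) ** 2)  # impurity used to score splits
print(node_prediction, node_mse)

In scikit-learn, this is what DecisionTreeRegressor does with its default squared-error criterion.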

When to stop?

One important issue with decision trees is knowing when to stop growing the model, which would otherwise overfit the data quite strongly. To avoid that, decision trees have many hyperparameters to control when to stop growing the tree. These can be found in the scikit-learn implementation of the model, and include hyperparameters such as max_depth, min_samples_split, min_samples_leaf, min_weight_fraction_leaf, and max_leaf_nodes. Note that while decision trees are very prone to overfitting (if you let them grow too deep), decreasing the depth reduces variance but increases bias. It is generally recommended to use cross-validation to perform hyperparameter selection.
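A minimal sketch of what such a cross-validated comparison could look like (the depth values below are arbitrary illustrations, not recommendations):

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

data = load_breast_cancer()

# Compare a fully grown tree (max_depth=None) with depth-limited ones
for depth in [None, 3, 5]:
    tree = DecisionTreeClassifier(max_depth=depth, min_samples_leaf=5, random_state=0)
    scores = cross_val_score(tree, data.data, data.target, cv=5)
    print(depth, round(scores.mean(), 3))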

Practical Demonstration

I decided not to present a practical demonstration of simple decision trees here, since they are rarely used on their own in real-world situations. Instead, machine learning engineers typically like to use them as ensembles, that is, combinations of multiple models. Put together, these ensembles can have a much higher predictive power, hence the preference for their use (they also have lower variance than single models). Using multiple models together makes intuitive sense: if you want to know whether a dentist is good, you typically ask more than one person ;)

However, it should be noted that one loses model interpretability by using ensembles. By extension, we lose the nice visualizations that one can produce with single decision trees. I typically run single trees to explain to non-tech stakeholders what is happening in the background, and then move on to random forests for actual predictions.

It should be noted that although you lose some explanatory power by using random forests, you can still do some nice feature importance calculations, by taking the average reduction in impurity across your trees. Say one decision tree reduces impurity by a factor of 5 by splitting on feature X, and another tree reduces impurity by a factor of 15 by using the same feature X; one could take the average reduction in impurity as a feature importance estimate.
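In scikit-learn, the feature_importances_ attribute of a fitted forest is essentially this average of the per-tree impurity reductions (normalized to sum to one). A minimal sketch, assuming a fitted RandomForestClassifier called clf like the one trained later in this post:

import numpy as np

# Each fitted tree exposes its own impurity-based importances;
# averaging them across the ensemble gives the forest-level ranking
per_tree = np.array([tree.feature_importances_ for tree in clf.estimators_])
avg_importance = per_tree.mean(axis=0)
print(avg_importance.argsort()[::-1][:5])   # indices of the five most important features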

Random forest models train a series of decision trees that differ in several parameters and then combine them via voting. Note that ensemble models do not necessarily use the same model, but can combine different estimators (SVM, linear models, etc.).

As far as I know (but fact-checking needed), random forests use decision trees with bagging. This means that each decision tree randomly samples from the training data with replacement. Since you sample with replacement, each model is exposed to approximately 63% of the data (resampling without replacement is called pasting, but this approach is rarely used). A nice thing with this approach is that it is fast, since training and predictions can be run in parallel. The remaining data points are called out-of-bag samples. These samples can be used for evaluation without actually needing a cross-validation set (this can be useful with very large data, where training multiple cross-validated models can be costly).

Moreover, random forests also randomly sample the features considered at each split (sampling both training instances and features is sometimes called the random patches method).
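A minimal sketch of how these ideas surface in scikit-learn's API: bootstrap controls the bagging, max_features the per-split feature sampling, and oob_score the out-of-bag evaluation (the parameter values below are illustrative only):

from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier

data = load_breast_cancer()

# bootstrap=True resamples the training data with replacement (bagging),
# max_features='sqrt' samples a random subset of features at each split,
# oob_score=True estimates generalization accuracy from the out-of-bag samples
forest = RandomForestClassifier(n_estimators=200, bootstrap=True,
                                max_features='sqrt', oob_score=True,
                                n_jobs=-1, random_state=0)
forest.fit(data.data, data.target)
print(forest.oob_score_)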

Random Forest with SciKit Learn

In [1]:
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, GridSearchCV
import numpy as np
import pandas as pd
%matplotlib inline

Let's split the data into test and train sets

In [11]:
data = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(data.data, data.target, test_size=0.20, random_state=42)
In [25]:
unique, counts = np.unique(y_train, return_counts=True)
print(np.asarray((unique, counts)).T)
[[  0 169]
 [  1 286]]

It seems we have somewhat imbalanced classes in the dataset. As a result, I will use the F1 score to evaluate my model.

In [26]:
from sklearn.metrics import f1_score, classification_report

We can test different levels of hyperparameters, to illustrate how tuning might help in getting a better prediction. We will tune the following parameters: n_estimators, max_depth & class_weight

In [14]:
# Number of trees in the model
n_estimators = [500, 1000, 2500, 3000]
In [16]:
# Maximum depth of the tree, after which the tree stops growing
max_depth = [1, 3, 5, 8, 10]

The class_weight parameter determines the weight associated with each class. As stated in the scikit-learn docs, None supposes that all classes have weight one, while the “balanced” mode uses the values of y to automatically adjust weights inversely proportional to class frequencies in the input data, as n_samples / (n_classes * np.bincount(y)).

In [17]:
class_weights = ['balanced', None]
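To illustrate the 'balanced' formula quoted above on our training labels (a quick check, independent of the grid search):

# n_samples / (n_classes * np.bincount(y)) from the scikit-learn docs
n_samples = len(y_train)
n_classes = len(np.unique(y_train))
print(n_samples / (n_classes * np.bincount(y_train)))
# the minority class (0) gets a weight above 1, the majority class (1) below 1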

I put these different levels of hyperparameters together to use GridSearchCV

In [42]:
# Hyperparameter grid for GridSearchCV
mod_params = {'n_estimators': [500, 1000, 2500, 3000],
              'max_depth': [1, 3, 5, 8, 10],
              'class_weight': ['balanced', None]}
In [34]:
# Load model
mod = RandomForestClassifier()
# Build Grid
clf = GridSearchCV(mod, mod_params)
# Fit
clf.fit(X_train, y_train)
In [40]:
pd.DataFrame(clf.cv_results_).T
Out[40]:
0 1 2 3 4 5 6 7 8 9 ... 30 31 32 33 34 35 36 37 38 39
mean_fit_time 0.50935 1.01997 2.70075 3.05275 0.61177 1.25383 3.09663 3.51246 0.626128 1.22403 ... 3.03512 3.6356 0.615584 1.2277 3.07945 3.70168 0.622788 1.22823 3.07964 3.69594
std_fit_time 0.0188284 0.0267305 0.15817 0.0460558 0.0185718 0.0182008 0.113879 0.0817199 0.0272952 0.00712271 ... 0.0116173 0.0301844 0.0147044 0.0177375 0.0308206 0.030183 0.012782 0.0211802 0.0539122 0.0332577
mean_score_time 0.0324838 0.0657083 0.165283 0.209511 0.0373786 0.0654385 0.178051 0.192293 0.0327048 0.0647789 ... 0.159296 0.208122 0.0324129 0.0636556 0.15952 0.191036 0.0322443 0.064167 0.159398 0.191194
std_score_time 0.00135676 0.0032327 0.00984867 0.0172536 0.00510942 0.000941093 0.0201501 0.000719859 0.000424297 0.000112396 ... 0.000881087 0.021166 0.000253351 0.000200346 0.000334815 0.000891689 0.000293512 0.000676815 0.00127708 0.000806503
param_class_weight balanced balanced balanced balanced balanced balanced balanced balanced balanced balanced ... None None None None None None None None None None
param_max_depth 1 1 1 1 3 3 3 3 5 5 ... 5 5 8 8 8 8 10 10 10 10
param_n_estimators 500 1000 2500 3000 500 1000 2500 3000 500 1000 ... 2500 3000 500 1000 2500 3000 500 1000 2500 3000
params {'class_weight': 'balanced', 'max_depth': 1, '... {'class_weight': 'balanced', 'max_depth': 1, '... {'class_weight': 'balanced', 'max_depth': 1, '... {'class_weight': 'balanced', 'max_depth': 1, '... {'class_weight': 'balanced', 'max_depth': 3, '... {'class_weight': 'balanced', 'max_depth': 3, '... {'class_weight': 'balanced', 'max_depth': 3, '... {'class_weight': 'balanced', 'max_depth': 3, '... {'class_weight': 'balanced', 'max_depth': 5, '... {'class_weight': 'balanced', 'max_depth': 5, '... ... {'class_weight': None, 'max_depth': 5, 'n_esti... {'class_weight': None, 'max_depth': 5, 'n_esti... {'class_weight': None, 'max_depth': 8, 'n_esti... {'class_weight': None, 'max_depth': 8, 'n_esti... {'class_weight': None, 'max_depth': 8, 'n_esti... {'class_weight': None, 'max_depth': 8, 'n_esti... {'class_weight': None, 'max_depth': 10, 'n_est... {'class_weight': None, 'max_depth': 10, 'n_est... {'class_weight': None, 'max_depth': 10, 'n_est... {'class_weight': None, 'max_depth': 10, 'n_est...
split0_test_score 0.923077 0.901099 0.912088 0.912088 0.934066 0.923077 0.923077 0.934066 0.956044 0.956044 ... 0.956044 0.956044 0.956044 0.978022 0.978022 0.978022 0.967033 0.978022 0.967033 0.967033
split1_test_score 0.901099 0.901099 0.901099 0.901099 0.923077 0.934066 0.923077 0.923077 0.945055 0.934066 ... 0.945055 0.945055 0.945055 0.934066 0.945055 0.945055 0.945055 0.945055 0.945055 0.945055
split2_test_score 0.934066 0.934066 0.934066 0.934066 0.978022 0.978022 0.978022 0.978022 0.978022 0.978022 ... 0.978022 0.978022 0.978022 0.978022 0.978022 0.978022 0.978022 0.978022 0.978022 0.978022
split3_test_score 0.901099 0.923077 0.912088 0.912088 0.945055 0.945055 0.945055 0.945055 0.956044 0.956044 ... 0.956044 0.956044 0.956044 0.956044 0.956044 0.956044 0.956044 0.956044 0.956044 0.956044
split4_test_score 0.934066 0.934066 0.934066 0.934066 0.956044 0.956044 0.956044 0.956044 0.967033 0.967033 ... 0.945055 0.945055 0.934066 0.945055 0.945055 0.945055 0.945055 0.956044 0.956044 0.956044
mean_test_score 0.918681 0.918681 0.918681 0.918681 0.947253 0.947253 0.945055 0.947253 0.96044 0.958242 ... 0.956044 0.956044 0.953846 0.958242 0.96044 0.96044 0.958242 0.962637 0.96044 0.96044
std_test_score 0.0149062 0.0149062 0.0131868 0.0131868 0.0189062 0.0189062 0.0208502 0.0189062 0.0112066 0.0145786 ... 0.0120379 0.0120379 0.0145786 0.0175824 0.0149062 0.0149062 0.0128153 0.0131868 0.0112066 0.0112066
rank_test_score 34 33 34 34 27 27 32 27 3 13 ... 19 19 24 13 3 3 13 1 3 3

16 rows × 40 columns

In [41]:
# Get best estimator
clf.best_estimator_
Out[41]:
RandomForestClassifier(bootstrap=True, ccp_alpha=0.0, class_weight='balanced',
                       criterion='gini', max_depth=8, max_features='auto',
                       max_leaf_nodes=None, max_samples=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, n_estimators=1000,
                       n_jobs=None, oob_score=False, random_state=None,
                       verbose=0, warm_start=False)

I re-run the winning model, entering the parameters manually. Note that scikit-learn has a refit parameter in GridSearchCV that would do this step automatically.

In [43]:
from sklearn.metrics import classification_report
In [44]:
clf = RandomForestClassifier(n_estimators=1000, max_depth=8, class_weight='balanced')
clf.fit(X_train, y_train)
test_predictions = clf.predict(X_test)
print("Test Classification Report:")
print(classification_report(y_test, test_predictions))
Test Classification Report:
              precision    recall  f1-score   support

           0       0.98      0.93      0.95        43
           1       0.96      0.99      0.97        71

    accuracy                           0.96       114
   macro avg       0.97      0.96      0.96       114
weighted avg       0.97      0.96      0.96       114

We mentioned before that despite a loss in explainability due to ensemble models, random forests still allow us to rank features by their importance for the predictions.

Let's have a quick look at feature importance, to visualize which features played the most important role in the predictions.

In [45]:
feature_imp = sorted(list(zip(data.feature_names, clf.feature_importances_)), key=lambda x: x[1], reverse=True)
In [46]:
pd.Series([x[1] for x in feature_imp], index=[x[0] for x in feature_imp]).plot(kind='bar')
Out[46]:
<matplotlib.axes._subplots.AxesSubplot at 0x7f311a9599d0>

Final words

That's it for decision trees and random forests! Let me know if you have feedback or constructive comments!